knitr::opts_chunk$set(fig.align="center")
library(rstanarm)
library(tidyverse)
library(tidybayes)
library(modelr)
library(ggplot2)
library(magrittr)
library(emmeans)
library(bayesplot)
library(brms)
library(gganimate)
theme_set(theme_light())
source('helper_functions.R')
In our experiment, we used a visualization recommendation algorithm (composed of one search algorithm and one oracle algorithm) to generate visualizations for the user on one of two datasets. We then measured the user’s accuracy on two tasks: Find Extremum and Retrieve Value.
Given a search algorithm (bfs or dfs), an oracle (compassql or dziban), and a dataset (birdstrikes or movies), we would like to predict a user’s chance of answering the Find Extremum and Retrieve Value tasks correctly. In addition, we would like to know whether the choice of search algorithm and oracle has any meaningful impact on a user’s accuracy for these two tasks, and whether the participant’s group (student or professional) is associated with a difference in performance.
accuracy_data <- read.csv('split_by_participant_groups/accuracy.csv')
accuracy_data$oracle <- as.factor(accuracy_data$oracle)
accuracy_data$search <- as.factor(accuracy_data$search)
accuracy_data$dataset <- as.factor(accuracy_data$dataset)
models <- list()
draw_data <- list()
search_differences <- list()
oracle_differences <- list()
seed <- 12
The prior (normal(0.8, .1)) was derived from pilot studies; it describes the distribution of the chance of a correct answer for any given task. Because our pilot study was small, we aggregated these measurements (rather than deriving separate priors for each task) to minimize the effect of biases. We fit a logistic regression to estimate the probability of a correct answer under different conditions.
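Note that brms applies coefficient priors on the linear-predictor (log-odds) scale, so it can help to translate the intercept prior back to probabilities. A minimal base-R sketch of the probability range implied by normal(0.8, 0.1):

```r
# central 95% of the intercept prior, on the log-odds scale
logit_bounds <- qnorm(c(0.025, 0.975), mean = 0.8, sd = 0.1)

# map through the inverse logit to get probabilities
plogis(logit_bounds)
# roughly 0.65 to 0.73 baseline probability of a correct answer
```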
We can check that our priors look reasonable by sampling from the prior predictive distribution. The Bernoulli family (with a logit link) was selected because accuracy is a binary outcome: each response is either correct or incorrect.
prior <- brm(
bf(
accuracy ~ 0 + Intercept + oracle * search + dataset + task + participant_group + (1 | participant_id)
),
data = accuracy_data,
prior = c(prior(normal(0.8, .1), class = "b", coef = "Intercept"),
prior(normal(0, 2.5), class = "b")),
family = bernoulli(link = "logit"),
warmup = 500,
iter = 3000,
chains = 2,
cores = 2,
control = list(adapt_delta = 0.9),
seed = seed,
sample_prior = "only",
file = "models/prior_accuracy"
)
accuracy_data %>%
select(-accuracy) %>%
add_predicted_draws(prior, prediction = "accuracy", seed = seed) %>%
ggplot(aes(x = accuracy)) +
geom_density(fill = "gray", size = 0) +
scale_y_continuous(NULL, breaks = NULL) +
labs(subtitle = "Prior predictive distribution for task accuracy")
Now let’s make our actual model.
model <- brm(
bf(
accuracy ~ 0 + Intercept + oracle * search + dataset + task + participant_group + (1 | participant_id)
),
data = accuracy_data,
prior = c(prior(normal(0.8, .1), class = "b", coef = "Intercept"),
prior(normal(0, 2.5), class = "b")),
family = bernoulli(link = "logit"),
warmup = 500,
iter = 3000,
chains = 2,
cores = 2,
control = list(adapt_delta = 0.9),
seed = seed,
file = "models/accuracy"
)
In the summary table, we want to see Rhat values close to 1.0 and Bulk_ESS in the thousands.
summary(model)
## Family: bernoulli
## Links: mu = logit
## Formula: accuracy ~ 0 + Intercept + oracle * search + dataset + task + participant_group + (1 | participant_id)
## Data: accuracy_data (Number of observations: 144)
## Samples: 2 chains, each with iter = 3000; warmup = 500; thin = 1;
## total post-warmup samples = 5000
##
## Group-Level Effects:
## ~participant_id (Number of levels: 72)
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sd(Intercept) 1.12 0.81 0.06 3.05 1.00 944 1970
##
## Population-Level Effects:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## Intercept 0.84 0.10 0.64 1.04 1.00 7318
## oracledziban 2.23 0.95 0.55 4.27 1.00 4089
## searchdfs 2.26 0.95 0.64 4.34 1.00 3582
## datasetmovies 0.70 0.71 -0.55 2.25 1.00 3324
## task2.RetrieveValue 0.86 0.60 -0.30 2.08 1.00 5815
## participant_groupstudent -0.08 0.65 -1.27 1.32 1.00 3614
## oracledziban:searchdfs -2.93 1.31 -5.52 -0.38 1.00 3757
## Tail_ESS
## Intercept 3439
## oracledziban 3195
## searchdfs 2240
## datasetmovies 2516
## task2.RetrieveValue 3659
## participant_groupstudent 2491
## oracledziban:searchdfs 3647
##
## Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
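The same convergence check can be done programmatically; a sketch using brms’s `rhat()` accessor (the 0.01 tolerance is a common rule of thumb, not a brms default):

```r
# extract per-parameter split-Rhat values from the fitted model
rhats <- rhat(model)

# list any parameters whose Rhat strays from 1 by more than 0.01
names(rhats[abs(rhats - 1) > 0.01])
```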
Trace plots help us check whether there is evidence of non-convergence in the model.
plot(model)
In our pairs plots, we want to make sure we don’t have highly correlated parameters (high correlation means the model has difficulty distinguishing their individual effects).
pairs(
model,
pars = c("b_Intercept",
"b_datasetmovies",
"b_oracledziban",
"b_searchdfs",
"b_task2.RetrieveValue",
"b_participant_groupstudent"),
fixed = TRUE
)
As a quick check of our posterior, we compare the posterior predictive distribution to the observed data.
pp_check(model, type = "dens_overlay", nsamples = 100)
A confusion matrix can be used to check our correct classification rate (a useful measure of how well our model fits the data).
pred <- predict(model, type = "response")
pred <- if_else(pred[,1] > 0.5, 1, 0)
confusion_matrix <- table(pred, pull(accuracy_data, accuracy))
confusion_matrix
##
## pred 0 1
## 1 10 134
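The correct classification rate follows from the table (cells where the prediction matches the observed label, divided by the total); a sketch reusing the `pred` vector from above:

```r
# proportion of observations the model classifies correctly
mean(pred == pull(accuracy_data, accuracy))
# with the table above: 134 / 144, about 0.93
```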
Visualization of parameter effects via draws from our model posterior. The thicker, shorter line represents the 50% credible interval, while the thinner, longer line represents the 95% credible interval.
draw_data <- accuracy_data %>%
add_fitted_draws(model, seed = seed, re_formula = NA, scale = "response") %>%
group_by(search, oracle, dataset, task, .draw)
plot_data <- draw_data
plot_data$oracle<- gsub('compassql', 'CompassQL', plot_data$oracle)
plot_data$oracle<- gsub('dziban', 'Dziban', plot_data$oracle)
plot_data$search<- gsub('bfs', 'BFS', plot_data$search)
plot_data$search<- gsub('dfs', 'DFS', plot_data$search)
plot_data$dataset<- gsub('birdstrikes', 'Birdstrikes', plot_data$dataset)
plot_data$dataset<- gsub('movies', 'Movies', plot_data$dataset)
plot_data$Dataset<- plot_data$dataset
plot_data$condition <- paste(plot_data$oracle, plot_data$search, sep="\n")
draw_plot <- posterior_draws_plot(
  plot_data, "Dataset", TRUE,
  "Predicted Average Accuracy (p_correct)",
  "Oracle/Search Combination"
) +
  scale_alpha(guide = 'none') +
  coord_cartesian(xlim = c(0.4, 1)) +
  xlab("Predicted Average Accuracy (p_correct)")
draw_plot
Since some of the credible intervals in the plot overlap, we use mean_qi to get the exact numeric boundaries of each interval.
fit_info <- draw_data %>% group_by(search, oracle, dataset, task) %>% mean_qi(.value, .width = c(.95, .5))
fit_info
## # A tibble: 32 x 10
## # Groups: search, oracle, dataset [8]
## search oracle dataset task .value .lower .upper .width .point .interval
## <fct> <fct> <fct> <chr> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 bfs compas… birdstr… 1. Find… 0.681 0.436 0.874 0.95 mean qi
## 2 bfs compas… birdstr… 2. Retr… 0.818 0.572 0.958 0.95 mean qi
## 3 bfs compas… movies 1. Find… 0.792 0.528 0.963 0.95 mean qi
## 4 bfs compas… movies 2. Retr… 0.890 0.694 0.987 0.95 mean qi
## 5 bfs dziban birdstr… 1. Find… 0.933 0.772 0.995 0.95 mean qi
## 6 bfs dziban birdstr… 2. Retr… 0.967 0.873 0.998 0.95 mean qi
## 7 bfs dziban movies 1. Find… 0.961 0.850 0.998 0.95 mean qi
## 8 bfs dziban movies 2. Retr… 0.982 0.926 0.999 0.95 mean qi
## 9 dfs compas… birdstr… 1. Find… 0.935 0.783 0.995 0.95 mean qi
## 10 dfs compas… birdstr… 2. Retr… 0.968 0.878 0.998 0.95 mean qi
## # … with 22 more rows
We’d now like to see the difference in average accuracy between levels of search, oracle, and participant group for each task.
predictive_data <- accuracy_data %>%
add_fitted_draws(model, seed = seed, re_formula = NA, scale = "response")
Differences in search algorithms:
search_differences <- expected_diff_in_mean_plot(predictive_data, "search", "Difference in Mean Accuracy (p_correct)", "Task", "dataset")
## `summarise()` regrouping output by 'search', 'task', 'dataset' (override with `.groups` argument)
search_differences$plot
We can double-check the boundaries of the credible intervals to see whether each interval contains zero. If a 95% credible interval excludes zero, there is a 95% posterior probability of a nonzero difference in mean accuracy between the two conditions.
search_differences$intervals
## # A tibble: 8 x 9
## # Groups: search, dataset [2]
## search dataset task difference .lower .upper .width .point .interval
## <chr> <fct> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 bfs - … birdstri… 1. Find… -0.101 -0.221 0.0238 0.95 mean qi
## 2 bfs - … birdstri… 2. Retr… -0.0621 -0.171 0.00960 0.95 mean qi
## 3 bfs - … movies 1. Find… -0.0687 -0.187 0.0106 0.95 mean qi
## 4 bfs - … movies 2. Retr… -0.0387 -0.119 0.00431 0.95 mean qi
## 5 bfs - … birdstri… 1. Find… -0.101 -0.140 -0.0635 0.5 mean qi
## 6 bfs - … birdstri… 2. Retr… -0.0621 -0.0858 -0.0306 0.5 mean qi
## 7 bfs - … movies 1. Find… -0.0687 -0.0975 -0.0338 0.5 mean qi
## 8 bfs - … movies 2. Retr… -0.0387 -0.0552 -0.0155 0.5 mean qi
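Rather than scanning the table by eye, we can filter for intervals that exclude zero. A sketch, assuming the intervals tibble keeps the `.lower`/`.upper`/`.width` columns shown above:

```r
# keep only 95% intervals whose endpoints share a sign (i.e., exclude zero)
search_differences$intervals %>%
  filter(.width == 0.95, sign(.lower) == sign(.upper))
```

For search, every 95% interval in the table above straddles zero, so this filter returns no rows.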
Differences in oracle:
oracle_differences <- expected_diff_in_mean_plot(predictive_data, "oracle", "Difference in Mean Accuracy (p_correct)", "Task", "dataset")
## `summarise()` regrouping output by 'oracle', 'task', 'dataset' (override with `.groups` argument)
oracle_differences$plot
As before, we check whether each 95% credible interval contains zero.
oracle_differences$intervals
## # A tibble: 8 x 9
## # Groups: oracle, dataset [2]
## oracle dataset task difference .lower .upper .width .point .interval
## <chr> <fct> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 dziban -… birdstr… 1. Find… 0.0990 -0.0241 0.215 0.95 mean qi
## 2 dziban -… birdstr… 2. Retr… 0.0607 -0.00904 0.166 0.95 mean qi
## 3 dziban -… movies 1. Find… 0.0672 -0.0114 0.180 0.95 mean qi
## 4 dziban -… movies 2. Retr… 0.0379 -0.00410 0.120 0.95 mean qi
## 5 dziban -… birdstr… 1. Find… 0.0990 0.0616 0.138 0.5 mean qi
## 6 dziban -… birdstr… 2. Retr… 0.0607 0.0305 0.0850 0.5 mean qi
## 7 dziban -… movies 1. Find… 0.0672 0.0321 0.0956 0.5 mean qi
## 8 dziban -… movies 2. Retr… 0.0379 0.0151 0.0538 0.5 mean qi
Differences in participant group (student vs professional):
participant_group_differences <- expected_diff_in_mean_plot(predictive_data, "participant_group", "Difference in Mean Accuracy (p_correct)", "Task", NULL)
## `summarise()` regrouping output by 'participant_group', 'task' (override with `.groups` argument)
participant_group_differences$plot
As before, we check whether each 95% credible interval contains zero.
participant_group_differences$intervals
## # A tibble: 4 x 8
## # Groups: participant_group [1]
## participant_group task difference .lower .upper .width .point .interval
## <chr> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 student - profess… 1. Find … -0.0126 -0.125 0.0859 0.95 mean qi
## 2 student - profess… 2. Retri… -0.00530 -0.0701 0.0605 0.95 mean qi
## 3 student - profess… 1. Find … -0.0126 -0.0467 0.0242 0.5 mean qi
## 4 student - profess… 2. Retri… -0.00530 -0.0237 0.0130 0.5 mean qi
participant_group_differences_dataset <- expected_diff_in_mean_plot(predictive_data, "participant_group", "Difference in Mean Accuracy (p_correct)", "Task", "dataset")
## `summarise()` regrouping output by 'participant_group', 'task', 'dataset' (override with `.groups` argument)
participant_group_differences_dataset$plot